[Fmha] Sparse MLA decode kernel selection heuristics#2836

Merged
bkryu merged 4 commits into flashinfer-ai:main from PerkzZheng:perkzz/sparse-mla-perf on Mar 20, 2026
Conversation


@PerkzZheng PerkzZheng commented Mar 20, 2026

📌 Description

Summary

Improves the kernel-selection heuristics in FlashInfer's trtllm-gen FMHA runner (SM100/SM103 sparse MLA decode) and
cleans up the kernel selection loop in run().

Changes

selectSparseMlaGenerationKernel() (new function)

Separates sparse MLA selection from the generic generation path with
tuned per-config heuristics:

  • numHeadsQ=64 (KeepsMmaAb, tileSizeQ=64) — previously selected
    SwapsMmaAb (tileSizeQ=16); this is the main perf win (see below).
  • numHeadsQ=128 (KeepsMmaAb) — batch-aware 1CTA/2CTA selection:
    use 2CTA when batchSize * numCtasPerToken * 8 > MP (avoids
    under-utilizing SM at small batch).
  • numHeadsQ≤32 (SwapsMmaAb) — batch-aware tileSizeQ halving at
    batch=1: use tileSizeQ/2 when batchSize * maxNumCtasPerSeqKv ≤ MP/8
    (doubles head-splitting parallelism when GPU is under-utilized).
  • CgaSmemReduction guard scoped to MLA kernels only: suppress for
    tileSizeQ≥32 (headDimQk=576 exceeds smem budget); non-MLA paths
    unaffected.
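The batch-aware thresholds above can be sketched as two small standalone functions. This is an illustrative sketch only: the function names `chooseCtaMode` and `chooseTileSizeQ` are hypothetical, the real logic lives inside `selectSparseMlaGenerationKernel()` in `fmhaKernels.cuh`, and `MP` stands for the GPU's multiprocessor count.

```cpp
#include <cassert>

// Hypothetical names; the actual heuristic lives in
// selectSparseMlaGenerationKernel() and uses internal kernel metadata.
enum class CtaChoice { OneCta, TwoCta };

// numHeadsQ=128 path: prefer 2CTA only once there is enough work
// (batchSize * numCtasPerToken * 8 > MP) to avoid under-utilizing SMs
// at small batch sizes.
CtaChoice chooseCtaMode(int batchSize, int numCtasPerToken, int multiProcessorCount) {
  return (batchSize * numCtasPerToken * 8 > multiProcessorCount) ? CtaChoice::TwoCta
                                                                 : CtaChoice::OneCta;
}

// numHeadsQ<=32 path: halve tileSizeQ when the GPU would otherwise be
// under-utilized (batchSize * maxNumCtasPerSeqKv <= MP/8), doubling the
// head-splitting parallelism.
int chooseTileSizeQ(int tileSizeQ, int batchSize, int maxNumCtasPerSeqKv,
                    int multiProcessorCount) {
  bool underUtilized = batchSize * maxNumCtasPerSeqKv <= multiProcessorCount / 8;
  return underUtilized ? tileSizeQ / 2 : tileSizeQ;
}
```

For example, on a B200-class part with 148 SMs, batch=1 with numCtasPerToken=2 yields 16 ≤ 148, so the 1CTA variant is kept, while batch=128 crosses the threshold and selects 2CTA.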

run() refactoring

  • Replaces the unbounded while (mSelectNewKernel) loop with a bounded
    for loop (kMaxKernelSelectionPasses=4), making convergence
    explicit and verifiable.
  • Extracts buildLaunchConfig() and setNonPortableClusterIfNeeded()
    as private helpers to eliminate duplicated inline setup.
  • Replaces !mSelectNewKernel guards on mMultiCtasKvMode assignments
    in selectSparseMlaGenerationKernel with an isGmemReduction() check,
    preserving Disabled/CgaSmemReduction modes set by
    computeCtaAndClusterConfig on re-entry.
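The bounded-loop pattern described above can be sketched as follows. This is a simplified stand-in, not the real TllmGenFmhaKernel::run(): the SelectionState struct and selectKernel function are hypothetical, and only the loop structure mirrors the PR.

```cpp
#include <cassert>
#include <optional>

// Matches the constant named in the PR; the surrounding state is simplified.
constexpr int kMaxKernelSelectionPasses = 4;

struct SelectionState {
  int kernelId = -1;
  bool needsReselection = true;  // each trigger fires at most once
  int passes = 0;
};

// Returns the selected kernel id, or nullopt if the bound is exceeded,
// making convergence explicit instead of risking an unbounded loop.
std::optional<int> selectKernel(SelectionState& s) {
  for (int pass = 0; pass < kMaxKernelSelectionPasses; ++pass) {
    ++s.passes;
    if (!s.needsReselection) {
      return s.kernelId;  // converged on the previous pass
    }
    // ... candidate selection would go here, possibly requesting
    // one more pass by setting needsReselection again ...
    s.kernelId = pass;
    s.needsReselection = false;
  }
  return std::nullopt;  // bound exceeded: fail loudly instead of hanging
}
```

Because each re-selection trigger fires at most once, the loop converges within a few passes in practice; the bound exists to make that property verifiable.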

Artifact

Updates TRTLLM_GEN_FMHA cubin checksum for new KeepsMmaAb sparse
variants required by the numHeadsQ=64 path.


Benchmark — Sparse MLA decode speedup (SM100, DeepSeek-V3 config)

GPU: B200, topK=128, seqlen=[1k–32k], dtype=[bf16, fp8]

| numHeadsQ | batch | Speedup range |
|-----------|-------|---------------|
| 64 | 32 | 1.35–1.72x |
| 64 | 128 | 2.17–2.74x |
| 64 | 512 | 2.23–2.50x |
| 32 | 128 | 1.24–1.59x |
| 32 | 512 | 1.28–1.50x |
| 16/128 | all | ~1.0x (neutral, path unchanged) |

The numHeadsQ=64 improvement (1.35–2.74x) comes entirely from switching
to KeepsMmaAb/tileSizeQ=64 (previously SwapsMmaAb/tileSizeQ=16).

Tests

  • test_trtllm_gen_mla.py: 4512/4512 passed (3168 skipped, SM100 only)
  • test_trtllm_gen_attention.py: 20832/20832 passed (23544 skipped)

🔍 Related Issues

#2797

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Fixed FMHA kernel selection logic to use bounded iteration instead of unbounded loop, improving reliability and preventing potential deadlocks.
    • Updated internal artifact references and checksums for TRT-LLM FMHA kernels.
  • New Features

    • Added kernel parameter configuration option for shared paged KV index handling.

Port heuristic improvements from trtllm-gen MR flashinfer-ai#885 to FlashInfer's
trtllm-gen FMHA kernel selection (SM100/SM103 sparse MLA decode):

1. New `selectSparseMlaGenerationKernel()` separates sparse MLA selection
   from the non-sparse path with tuned heuristics:
   - numHeadsQ <= 32 (SwapsMmaAb): batch-aware tileSizeQ halving at
     batch=1 (2x more head-splitting CTAs when GPU is under-utilized);
     threshold batchSize * maxNumCtasPerSeqKv <= MP/8.
   - numHeadsQ >= 64 (KeepsMmaAb): batch-aware 1CTA/2CTA selection for
     numHeadsQ=128; threshold batchSize * numCtasPerToken * 8 > MP.
   - numHeadsQ=64 now uses KeepsMmaAb (tileSizeQ=64) instead of SwapsMmaAb
     (tileSizeQ=16), yielding 1.35-2.74x speedup across all batch sizes.

2. CgaSmemReduction guard scoped to MLA kernels: only suppress for
   tileSizeQ >= 32 (headDimQk=576 exceeds smem limit); non-MLA kernels
   are unaffected.

3. Fix kernel re-selection deadlocks by guarding mMultiCtasKvMode
   assignments with !mSelectNewKernel:
   - SwapsMmaAb path: preserve CgaSmemReduction upgrade on re-selection.
   - KeepsMmaAb path: preserve Disabled mode when numCtasPerSeqKv==1
     (small topK), avoiding infinite loop.

Also add benchmark script: benchmarks/bench_sparse_mla.py sweeping
batch=[1,32,128,512], seqlenKv=[1k-32k], numHeadsQ=[16,32,64,128],
dtype=[bf16,e4m3] for DeepSeek-V3 sparse MLA config.

Update TRTLLM_GEN_FMHA cubin artifact hash to e7afc4134b (new kernels
required for numHeadsQ=64 1CTA KeepsMmaAb sparse variants).

Validated: 4512/4512 test_trtllm_gen_mla.py tests pass (3168 skipped).
Replace unbounded while loop in run() with a bounded for loop
(kMaxKernelSelectionPasses=4) to make termination explicit and
verifiable. Each re-selection trigger fires at most once, so the
sequence always converges within 3 passes in practice.

Also refactor selectSparseMlaGenerationKernel to remove the
!mSelectNewKernel guard pattern on mMultiCtasKvMode assignments,
replacing it with an isGmemReduction() check that is semantically
equivalent but clearer: it preserves any Disabled mode set by
computeCtaAndClusterConfig on re-entry instead of unconditionally
overwriting it.
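The guard semantics described in this commit can be sketched as a tiny helper. This is a hedged illustration: the MultiCtasKvMode enumerators and the updateMode helper are assumptions based on the PR text, not the actual definitions in fmhaKernels.cuh.

```cpp
#include <cassert>

// Assumed enumerators, named after the modes mentioned in the PR text.
enum class MultiCtasKvMode { Disabled, GmemReduction, CgaSmemReduction };

bool isGmemReduction(MultiCtasKvMode m) { return m == MultiCtasKvMode::GmemReduction; }

// Only overwrite the mode while it still holds the default GmemReduction;
// a Disabled or CgaSmemReduction choice made by computeCtaAndClusterConfig
// on an earlier pass is preserved on re-entry.
MultiCtasKvMode updateMode(MultiCtasKvMode current, MultiCtasKvMode desired) {
  return isGmemReduction(current) ? desired : current;
}
```

Compared with the old !mSelectNewKernel guard, checking the mode itself makes the intent local: the assignment is skipped exactly when a prior pass already made a decision.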

Extract buildLaunchConfig() and setNonPortableClusterIfNeeded() as
private helpers to eliminate repeated inline setup in run().

Verified no regression across 4512 MLA + 20832 GQA/context tests,
and benchmarked GQA (headsQPerKv=4/8/16) and dense MLA decode across
batch=[1,4,16,64,256,512], seq=[1024,4096,16384], dtype=[bf16,fp8]:
max absolute delta <10us at all reliably-measurable latencies.

Add benchmarks/bench_decode_regression.py covering GQA and MLA decode
with the above sweep grid for future regression comparisons.

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

Walkthrough

This PR updates TRTLLM FMHA kernel artifacts and refactors kernel selection logic. Changes include replacing an unbounded while loop with a bounded for loop (max 4 passes), restructuring CgaSmemReduction fallback handling, and adjusting MLA kernel selection heuristics. A new kernel parameter field mUsesSharedPagedKvIdx is added.

Changes

| Cohort / File(s) | Summary |
|------------------|---------|
| Artifact Updates — flashinfer/artifacts.py | Updated ArtifactPath.TRTLLM_GEN_FMHA and the corresponding CheckSumHash.TRTLLM_GEN_FMHA to new versioned path and checksum values. |
| Kernel Selection Refactoring — include/flashinfer/trtllm/fmha/fmhaKernels.cuh | Refactored TllmGenFmhaKernel::run from an unbounded while loop to a bounded for loop (4 passes max). Added buildLaunchConfig and setNonPortableClusterIfNeeded helpers. Restructured CgaSmemReduction fallback handling to occur post-selection and adjusted its eligibility gating for MLA kernels. Reworked MLA heuristics by extracting the sparse path to selectSparseMlaGenerationKernel with updated tile/CTA selection criteria. |
| Kernel Parameters — include/flashinfer/trtllm/fmha/kernelParams.h | Added boolean field mUsesSharedPagedKvIdx to the KernelParams struct, initialized to true by default. |


@gemini-code-assist

Summary of Changes

This pull request focuses on enhancing the performance and stability of sparse MLA decode operations within the FlashInfer framework. It introduces refined kernel selection heuristics, improves code structure, and updates necessary artifacts to ensure compatibility and optimal performance on SM100/SM103 architectures. The changes aim to address under-utilization issues and prevent potential shared memory overruns, resulting in significant speedups for specific configurations.

Highlights

  • Kernel Selection Heuristics: Improves kernel selection heuristics for sparse MLA decode in the trtllm-gen FMHA runner, specifically targeting SM100/SM103 architectures.
  • Performance Optimization: Optimizes performance by selecting the KeepsMmaAb kernel with tileSizeQ=64 for numHeadsQ=64, resulting in significant speedups (1.35x-2.74x) in certain configurations.
  • Code Refactoring: Refactors the run() function to use a bounded for loop for kernel selection and extracts helper functions to improve code readability and maintainability.
  • CGA Smem Reduction: Adds a guard for CgaSmemReduction scoped to MLA kernels to prevent shared memory budget overruns for tileSizeQ≥32.



@PerkzZheng force-pushed the perkzz/sparse-mla-perf branch from e36f412 to e2b5819 on March 20, 2026 05:37

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant performance improvements for sparse MLA decode on SM100/SM103 architectures by porting and refining kernel selection heuristics. The introduction of selectSparseMlaGenerationKernel cleanly separates the new logic, and the heuristics for different head counts and batch sizes are well-documented and appear effective based on the provided benchmarks.

The refactoring in run() is a major improvement for code quality and safety, replacing an unbounded while loop with a bounded for loop and extracting helper functions to improve readability. The added benchmark script is also a valuable contribution for future performance tracking.

Overall, this is a high-quality contribution that improves both performance and maintainability. I have one minor suggestion for cleanup in the new benchmark script.


benchmarks/bench_sparse_mla.py, line 242 (medium):

The done variable, initialized on line 207, is incremented here but its value is never used. Progress is already tracked and printed using len(results) and total. This variable can be removed for code clarity.

@bkryu bkryu added the run-ci label Mar 20, 2026

bkryu commented Mar 20, 2026

/bot run

@flashinfer-bot

GitLab MR !437 has been created, and the CI pipeline #46582171 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

517-518: Remove extraneous semicolon.

There's a stray semicolon on what would be line 518 (after the maxNumCtasPerSeqKv declaration).

🧹 Proposed fix
     int const maxNumCtasPerSeqKv = flashinfer::ceil_div(params.mMaxSeqLenKv, 256);
-    ;


📥 Commits

Reviewing files that changed from the base of the PR and between ad893cf and e2b5819.

📒 Files selected for processing (3)
  • flashinfer/artifacts.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/kernelParams.h

@flashinfer-bot

[SUCCESS] Pipeline #46582171: 13/20 passed


@bkryu bkryu left a comment


Both CIs look good to me. Thanks @PerkzZheng

@bkryu bkryu merged commit c72f62a into flashinfer-ai:main Mar 20, 2026
45 of 73 checks passed
